{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook contains code to train a linear classifier on MNIST. The principal changes from the previous notebook are:\n",
    "\n",
    "* We have switched from regression to classification.\n",
    "\n",
    "* We are using a different loss function. Instead of squared error, we now use cross-entropy.\n",
    "\n",
    "* We are using a new dataset. MNIST contains 28x28-pixel grayscale images of handwritten digits.\n",
    "\n",
    "An important takeaway: despite these changes, the line that creates the gradient descent optimizer is identical to the one in the previous notebook. This is the magic of automatic differentiation. Once we've specified the graph and the loss function, TensorFlow can analyze them for us and determine how to adjust our variables to decrease the loss.\n",
    "\n",
    "The model we train here is unimpressive in terms of accuracy. The goal is to introduce you to the dataset. At the end is a short exercise.\n",
    "\n",
    "Experiment with this notebook by running the cells and uncommenting code when asked.\n",
    "\n",
    "When you've finished with this notebook, move on to the next one, which turns our linear classifier into a deep neural network and adds code to visualize the graph in TensorBoard."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import pylab\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.Session()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the MNIST dataset. \n",
    "# It will be downloaded to '/tmp/data' if you don't already have a local copy.\n",
    "mnist = input_data.read_data_sets('/tmp/data', one_hot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Uncomment these lines to understand the format of the dataset.\n",
    "\n",
    "# 1. There are 55k, 5k, and 10k examples in train, validation, and test.\n",
    "# print ('Train, validation, test: %d, %d, %d' %\n",
    "#       (len(mnist.train.images), len(mnist.validation.images), len(mnist.test.images)))\n",
    "\n",
    "# 2. The format of the labels is 'one-hot': a length-10 vector with a\n",
    "# single 1 at the index of the digit. For example, the digit 5 is\n",
    "# represented as '[ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]'.\n",
    "# Print the label of the fifth training image:\n",
    "# print (mnist.train.labels[4])\n",
    "\n",
    "# You can find the index of the label, like this:\n",
    "# print (np.argmax(mnist.train.labels[4]))\n",
    "\n",
    "# 3. An image is a 'flattened' array of 28*28 = 784 pixels.\n",
    "# print (len(mnist.train.images[4]))\n",
    "\n",
    "# 4. To display an image, first reshape it to 28x28.\n",
    "# pylab.imshow(mnist.train.images[4].reshape((28, 28)), cmap=pylab.cm.gray_r)\n",
    "# pylab.title('Label: %d' % np.argmax(mnist.train.labels[4]))"
   ]
  },
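  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The one-hot format described above can be written compactly. In notation introduced here for illustration (it does not appear in the code), a digit $c \\in \\{0, \\dots, 9\\}$ is encoded as a vector $t$ with entries\n",
    "\n",
    "$$t_k = \\begin{cases} 1 & \\text{if } k = c \\\\ 0 & \\text{otherwise} \\end{cases} \\qquad k = 0, \\dots, 9$$\n",
    "\n",
    "so that $\\arg\\max_k t_k = c$. This is why `np.argmax` recovers the digit from a label. For example, $c = 5$ gives `[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]`."
   ]
  },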
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "NUM_CLASSES = 10\n",
    "NUM_PIXELS = 28 * 28\n",
    "TRAIN_STEPS = 2000\n",
    "BATCH_SIZE = 100\n",
    "LEARNING_RATE = 0.5"
   ]
  },
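  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for how much data these settings consume, here is a rough back-of-the-envelope calculation. It writes $S$ for `TRAIN_STEPS` and $B$ for `BATCH_SIZE` (symbols introduced here), uses the 55k training-set size noted earlier, and assumes that `next_batch` steps through the training set without repeating examples within a pass. Treat it as an estimate, not an exact count.\n",
    "\n",
    "$$S \\times B = 2000 \\times 100 = 200{,}000 \\text{ examples seen}$$\n",
    "\n",
    "$$\\frac{200{,}000}{55{,}000} \\approx 3.6 \\text{ passes over the training set}$$\n",
    "\n",
    "Under these assumptions, training makes roughly three and a half passes over the training images."
   ]
  },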
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Define inputs\n",
    "images = tf.placeholder(dtype=tf.float32, shape=[None, NUM_PIXELS])\n",
    "labels = tf.placeholder(dtype=tf.float32, shape=[None, NUM_CLASSES])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
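   },
   "outputs": [],
   "source": [
    "# Define model: a single linear layer.\n",
    "# W maps each of the 784 pixels to a score for each of the 10 classes.\n",
    "W = tf.Variable(tf.truncated_normal([NUM_PIXELS, NUM_CLASSES]))\n",
    "b = tf.Variable(tf.zeros([NUM_CLASSES]))\n",
    "# y holds the unnormalized class scores (logits), shape [batch, NUM_CLASSES].\n",
    "y = tf.matmul(images, W) + b"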
   ]
  },
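  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In equation form, the model cell above can be read as follows. The symbols $x$, $p$ and the index $k$ are notation introduced here for illustration; only `W`, `b` and `y` appear in the code. For a single flattened image $x$:\n",
    "\n",
    "$$y = xW + b, \\qquad x \\in \\mathbb{R}^{1 \\times 784}, \\; W \\in \\mathbb{R}^{784 \\times 10}, \\; b \\in \\mathbb{R}^{10}$$\n",
    "\n",
    "The ten entries of $y$ are unnormalized scores (logits), one per digit class. A softmax turns them into probabilities:\n",
    "\n",
    "$$p_k = \\frac{e^{y_k}}{\\sum_{j=0}^{9} e^{y_j}}, \\qquad k = 0, \\dots, 9$$\n",
    "\n",
    "The model has $784 \\times 10 + 10 = 7850$ trainable parameters. The softmax is not applied in the model cell itself; the loss function used in this notebook, `tf.nn.softmax_cross_entropy_with_logits`, applies it internally."
   ]
  },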
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
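   },
   "outputs": [],
   "source": [
    "# Define loss and optimizer.\n",
    "# softmax_cross_entropy_with_logits applies softmax to y internally,\n",
    "# so y is passed as raw logits; reduce_mean averages over the batch.\n",
    "loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=labels))\n",
    "train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)"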
   ]
  },
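  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss in the cell above can be written out explicitly. This is a sketch in notation introduced here, not taken from the code: $N$ is the batch size, $y_{nk}$ is the logit of example $n$ for class $k$, $t_{nk}$ is its one-hot label, $p_{nk}$ is the softmax of its logits, and $\\eta$ is `LEARNING_RATE`.\n",
    "\n",
    "$$L = -\\frac{1}{N} \\sum_{n=1}^{N} \\sum_{k=0}^{9} t_{nk} \\log p_{nk}, \\qquad p_{nk} = \\frac{e^{y_{nk}}}{\\sum_{j=0}^{9} e^{y_{nj}}}$$\n",
    "\n",
    "Because each label is one-hot, the inner sum picks out the log-probability assigned to the correct class. Differentiating with respect to the logits gives a simple form:\n",
    "\n",
    "$$\\frac{\\partial L}{\\partial y_{nk}} = \\frac{1}{N} \\left( p_{nk} - t_{nk} \\right)$$\n",
    "\n",
    "Let $X$ be the $N \\times 784$ batch of images, and let $P$ and $T$ be the $N \\times 10$ matrices of probabilities and labels, with rows $p_n$ and $t_n$. Applying the chain rule through $Y = XW + b$ gives the gradients:\n",
    "\n",
    "$$\\nabla_W L = \\frac{1}{N} X^\\top (P - T), \\qquad \\nabla_b L = \\frac{1}{N} \\sum_{n=1}^{N} (p_n - t_n)$$\n",
    "\n",
    "One gradient descent step then updates the variables as\n",
    "\n",
    "$$W \\leftarrow W - \\eta \\, \\nabla_W L, \\qquad b \\leftarrow b - \\eta \\, \\nabla_b L$$\n",
    "\n",
    "TensorFlow derives these gradients automatically from the graph, so none of this has to be coded by hand."
   ]
  },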
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Initialize variables after the model is defined\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Train the model: each step draws a mini-batch and runs one\n",
    "# gradient descent update on W and b.\n",
    "for i in range(TRAIN_STEPS):\n",
    "    batch_images, batch_labels = mnist.train.next_batch(BATCH_SIZE)\n",
    "    sess.run(train_step, feed_dict={images: batch_images, labels: batch_labels})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the trained model on the test set.\n",
    "# A prediction is correct when the highest-scoring class matches the label.\n",
    "correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(labels, 1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "print(\"Accuracy %f\" % sess.run(accuracy, feed_dict={images: mnist.test.images,\n",
    "                                                    labels: mnist.test.labels}))"
   ]
  },
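  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The accuracy computed by the evaluation code can be stated as a formula. The notation is introduced here for illustration: $M$ is the number of test examples (10k, as noted earlier), $y_m$ holds the logits for test image $m$, $t_m$ is its one-hot label, and $\\mathbf{1}[\\cdot]$ is 1 when its condition holds and 0 otherwise.\n",
    "\n",
    "$$\\text{accuracy} = \\frac{1}{M} \\sum_{m=1}^{M} \\mathbf{1}\\left[ \\arg\\max_k y_{mk} = \\arg\\max_k t_{mk} \\right]$$\n",
    "\n",
    "Taking the argmax of the raw logits is enough here: softmax preserves the ordering of the logits, so it does not change which class scores highest."
   ]
  },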
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The evaluation code above measures the accuracy of the trained model on the entire test set. The function below predicts the label for a single test image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index of the highest-scoring class for each input image.\n",
    "prediction = tf.argmax(y, 1)\n",
    "\n",
    "def predict(i):\n",
    "    \"\"\"Return (predicted, actual) digit labels for test image i.\"\"\"\n",
    "    image = mnist.test.images[i]\n",
    "    actual_label = np.argmax(mnist.test.labels[i])\n",
    "    # The model expects a batch, so wrap the image in a list and take\n",
    "    # the single result back out of the returned array.\n",
    "    predicted_label = sess.run(prediction, feed_dict={images: [image]})[0]\n",
    "    return predicted_label, actual_label\n",
    "\n",
    "i = 0\n",
    "predicted, actual = predict(i)\n",
    "print(\"Predicted: %d, actual: %d\" % (predicted, actual))\n",
    "pylab.imshow(mnist.test.images[i].reshape((28, 28)), cmap=pylab.cm.gray_r)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
